{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "## 1. Load the input text"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "15c22a3a9eba142c"
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "initial_id",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2024-05-30T07:15:11.364650100Z",
     "start_time": "2024-05-30T07:15:11.356234500Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "length of this dataset is: 1115394\n"
     ]
    }
   ],
   "source": [
    "# Load the whole training corpus into memory as a single string.\n",
    "with open('./input.txt', mode='r', encoding='utf-8') as f:\n",
    "    text = f.read()\n",
    "\n",
    "# Report the corpus size, measured in characters.\n",
    "print('length of this dataset in characters:', len(text))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz\n",
      "65\n"
     ]
    }
   ],
   "source": [
    "# Vocabulary: every distinct character of the corpus, in sorted order.\n",
    "chars = sorted(set(text))\n",
    "vocab_size = len(chars)\n",
    "\n",
    "# Show the vocabulary as one string, followed by its size.\n",
    "print(''.join(chars))\n",
    "print(vocab_size)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-30T07:18:46.624794200Z",
     "start_time": "2024-05-30T07:18:46.573257800Z"
    }
   },
   "id": "2adde87443f5cba5"
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[46, 47, 47, 1, 58, 46, 43, 56, 43]\n",
      "hii there\n"
     ]
    }
   ],
   "source": [
    "# Character <-> integer lookup tables built from the sorted vocabulary.\n",
    "itos = dict(enumerate(chars))\n",
    "stoi = {ch: i for i, ch in itos.items()}\n",
    "\n",
    "\n",
    "def encode(s):\n",
    "    \"\"\"Encoder: take a string, return the list of its integer token ids.\"\"\"\n",
    "    return [stoi[c] for c in s]\n",
    "\n",
    "\n",
    "def decode(l):\n",
    "    \"\"\"Decoder: take a list of integer token ids, return the string they spell.\"\"\"\n",
    "    return ''.join(itos[i] for i in l)\n",
    "\n",
    "\n",
    "# Round-trip a sample string to show that the two mappings are inverses.\n",
    "print(encode(\"hii there\"))\n",
    "print(decode(encode(\"hii there\")))"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-30T07:19:39.136672900Z",
     "start_time": "2024-05-30T07:19:39.111144Z"
    }
   },
   "id": "ee9404107273b9df"
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1115394]) torch.int64\n",
      "tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "# Tokenize the entire corpus and keep it as one 1-D tensor of int64 token ids.\n",
    "data = torch.tensor(encode(text), dtype=torch.int64)\n",
    "print(data.shape, data.dtype)\n",
    "# Peek at the first 10 token ids: this is how the model sees the start of the text.\n",
    "print(data[:10])"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-30T07:22:22.818107600Z",
     "start_time": "2024-05-30T07:22:15.360250100Z"
    }
   },
   "id": "eb92f02df524ab34"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Train on the first 90% of the tokens and hold out the remaining 10% for validation.\n",
    "n = int(0.9*len(data))  # index of the train/validation boundary\n",
    "train_data, val_data = data[:n], data[n:]"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "108fc2324b896738"
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 2.理解self attention"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "9ae328ff1e03b84b"
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "outputs": [
    {
     "data": {
      "text/plain": "torch.Size([4, 8, 16])"
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch.nn as nn\n",
    "from torch.nn import functional as F\n",
    "# Fix the RNG so that the random input and the layer initialisations are reproducible.\n",
    "torch.manual_seed(1337)\n",
    "B, T, C = 4, 8, 32  # batch, time, channels\n",
    "x = torch.randn(B, T, C)\n",
    "\n",
    "# A single self-attention head, written out step by step.\n",
    "head_size = 16\n",
    "key = nn.Linear(C, head_size, bias=False)\n",
    "query = nn.Linear(C, head_size, bias=False)\n",
    "value = nn.Linear(C, head_size, bias=False)\n",
    "\n",
    "# Project every token to a key vector and a query vector.\n",
    "k = key(x)    # (B, T, head_size)\n",
    "q = query(x)  # (B, T, head_size)\n",
    "\n",
    "# Raw affinities between every pair of positions:\n",
    "# (B, T, head_size) @ (B, head_size, T) -> (B, T, T)\n",
    "wei = torch.matmul(q, k.transpose(-1, -2))\n",
    "\n",
    "# Causal mask: the lower-triangular matrix lets position t attend only to positions <= t.\n",
    "tril = torch.tril(torch.ones(T, T))\n",
    "wei = wei.masked_fill(tril == 0, float('-inf'))\n",
    "wei = wei.softmax(dim=-1)  # (B, T, T); every row sums to 1\n",
    "\n",
    "# Aggregate the value vectors with the attention weights.\n",
    "v = value(x)                # (B, T, head_size)\n",
    "out = torch.matmul(wei, v)  # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)\n",
    "\n",
    "out.shape"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-30T07:39:45.138440300Z",
     "start_time": "2024-05-30T07:39:44.610353300Z"
    }
   },
   "id": "8b1cd9ad34576f4c"
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "outputs": [],
   "source": [
    "# Causal scaled dot-product self-attention head\n",
    "\n",
    "# Model hyper-parameters shared by the modules defined in this notebook.\n",
    "n_embd = 64                   # embedding width of every token\n",
    "n_head = 4                    # attention heads per multi-head layer\n",
    "n_layer = 4                   # number of stacked blocks\n",
    "head_size = n_embd // n_head  # 16: width of one head, so that all heads together span n_embd\n",
    "dropout = 0.1                 # dropout probability\n",
    "block_size = 8                # maximum context length; also the side of the causal mask\n",
    "\n",
    "\n",
    "class Head(nn.Module):\n",
    "    \"\"\"One head of causal (masked) scaled dot-product self-attention.\n",
    "\n",
    "    Args:\n",
    "        head_size: width of the key / query / value projections of this head.\n",
    "\n",
    "    ``forward`` maps an input of shape (B, T, n_embd), with T <= block_size,\n",
    "    to an output of shape (B, T, head_size).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, head_size):\n",
    "        super().__init__()\n",
    "        self.key = nn.Linear(n_embd, head_size, bias=False)\n",
    "        self.query = nn.Linear(n_embd, head_size, bias=False)\n",
    "        self.value = nn.Linear(n_embd, head_size, bias=False)\n",
    "        # Lower-triangular mask kept as a buffer: it is part of the module state\n",
    "        # but is not a trainable parameter.\n",
    "        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))\n",
    "\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "\n",
    "    def forward(self, x):\n",
    "        _, T, _ = x.shape\n",
    "        k = self.key(x)    # (B, T, head_size)\n",
    "        q = self.query(x)  # (B, T, head_size)\n",
    "        # Attention scores (\"affinities\"), scaled by 1/sqrt(d_k). d_k is the width of the\n",
    "        # keys (head_size), not the embedding width of x; scaling by the embedding width\n",
    "        # would shrink the scores too much and flatten the softmax.\n",
    "        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # (B, T, hs) @ (B, hs, T) -> (B, T, T)\n",
    "        # Hide future positions so that token t only attends to tokens <= t.\n",
    "        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # (B, T, T)\n",
    "        wei = F.softmax(wei, dim=-1)  # (B, T, T)\n",
    "        wei = self.dropout(wei)\n",
    "        # Weighted aggregation of the values.\n",
    "        v = self.value(x)  # (B, T, head_size)\n",
    "        out = wei @ v      # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)\n",
    "        return out"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-30T07:45:34.035947600Z",
     "start_time": "2024-05-30T07:45:34.020429300Z"
    }
   },
   "id": "99d6006746924b81"
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "outputs": [],
   "source": [
    "# Multi-headed self-attention\n",
    "class MultiHeadAttention(nn.Module):\n",
    "    \"\"\"Several self-attention heads run in parallel, then mixed by a linear projection.\n",
    "\n",
    "    Args:\n",
    "        num_heads: number of parallel ``Head`` modules.\n",
    "        head_size: output width of each head.\n",
    "\n",
    "    ``forward`` maps (B, T, n_embd) to (B, T, n_embd).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_heads, head_size):\n",
    "        super().__init__()\n",
    "        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])\n",
    "        # The concatenated heads are num_heads * head_size wide. Size the projection from\n",
    "        # that width so it still maps back to n_embd when the heads do not tile n_embd exactly\n",
    "        # (identical to the usual case where num_heads * head_size == n_embd).\n",
    "        self.proj = nn.Linear(num_heads * head_size, n_embd)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Run every head on the same input and join the results along the channel axis.\n",
    "        out = torch.cat([h(x) for h in self.heads], dim=-1)  # (B, T, num_heads * head_size)\n",
    "        out = self.dropout(self.proj(out))                   # (B, T, n_embd)\n",
    "        return out\n"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-30T07:45:36.599384100Z",
     "start_time": "2024-05-30T07:45:36.583655900Z"
    }
   },
   "id": "1ab2037a0177da2e"
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "outputs": [
    {
     "data": {
      "text/plain": "torch.Size([4, 8, 64])"
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Confirm that multi-head attention returns tensors of the original embedding size.\n",
    "# The sizes are taken from the hyper-parameters rather than repeated as literals, and the\n",
    "# shape is asserted so that a mismatch fails loudly instead of only being displayed.\n",
    "B, T, C = 4, 8, n_embd  # batch, time, channels\n",
    "x = torch.randn(B, T, C)\n",
    "mha = MultiHeadAttention(n_head, head_size)\n",
    "mha_out = mha(x)\n",
    "assert mha_out.shape == x.shape, 'multi-head attention must preserve (B, T, n_embd)'\n",
    "mha_out.shape"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-30T07:45:48.079379800Z",
     "start_time": "2024-05-30T07:45:47.835776700Z"
    }
   },
   "id": "6b24076b953aa781"
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 3.创建专家模块 (即一个简单的MLP多层感知机)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "9f382de962aa6e6f"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Expert module\n",
    "class Expert(nn.Module):\n",
    "    \"\"\"A single expert: a small MLP (Linear -> ReLU -> Linear -> Dropout).\n",
    "\n",
    "    The hidden layer is four times wider than the n_embd-wide input and output.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_embd):\n",
    "        super().__init__()\n",
    "        hidden_dim = 4 * n_embd\n",
    "        layers = [\n",
    "            nn.Linear(n_embd, hidden_dim),  # expand\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(hidden_dim, n_embd),  # project back to the input width\n",
    "            nn.Dropout(dropout),\n",
    "        ]\n",
    "        self.net = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # The MLP acts on the last dimension only, so the shape of x is preserved.\n",
    "        return self.net(x)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "e22352309b2213a8"
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 4. 创建TopkRouter"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "774e25b91e183ade"
  },
  {
   "cell_type": "markdown",
   "source": [
    "假设我们定义了4个专家，路由取前2名专家。接收注意力层的输出作为输入，即将输入从（Batch size，Tokens，n_embed）的形状（2，4，32）投影到对应于（Batch size，Tokens，num_experts）的形状（2，4，4），其中num_experts是专家网络的计数。其中返回的indices可以理解为对于每个token的4个专家来说，选的两个专家的序号索引。"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "4e7e8e0a2de464fd"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# The shape comments below assume n_embed=32, num_experts=4 and top_k=2.\n",
    "\n",
    "class TopkRouter(nn.Module):\n",
    "    \"\"\"Gating network that keeps only the top_k experts for every token.\n",
    "\n",
    "    ``forward`` returns ``(router_output, indices)``: softmax weights over all experts that\n",
    "    are zero outside the selected ones, and the ids of the selected experts.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_embed, num_experts, top_k):\n",
    "        super().__init__()\n",
    "        self.top_k = top_k\n",
    "        self.linear = nn.Linear(n_embed, num_experts)\n",
    "\n",
    "    def forward(self, mh_output):\n",
    "        # One score per expert for every token, e.g. (2, 4, 32) -> (2, 4, 4).\n",
    "        scores = self.linear(mh_output)\n",
    "        # The top_k largest scores along the expert axis and the experts they belong to.\n",
    "        kept_scores, indices = torch.topk(scores, k=self.top_k, dim=-1)\n",
    "        # Start from a tensor of -inf shaped like the scores, e.g. (2, 4, 4), and write the\n",
    "        # kept scores back at the positions of their experts.\n",
    "        masked_scores = torch.full_like(scores, float('-inf')).scatter(\n",
    "            dim=-1, index=indices, src=kept_scores\n",
    "        )\n",
    "        # After the softmax every position that was left at -inf gets weight exactly 0.\n",
    "        return F.softmax(masked_scores, dim=-1), indices"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "87c76b2707bc8c96"
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 5.添加噪声路由"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "bfe6884a60d57b7f"
  },
  {
   "cell_type": "markdown",
   "source": [
    "从本质上讲，我们不希望所有token都发送给同一组“受青睐”的expert，而是希望各个expert之间的负载保持良好平衡。因此，我们在门控线性层输出的logits上添加高斯噪声：先采样标准正态噪声，再乘以由另一个线性层经softplus得到的可学习缩放系数。\n",
    "与上面的普通router相比，代码只改动了几行，非常简单。"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "ebcc3d100b88a6b6"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "class NoisyTopkRouter(nn.Module):\n",
    "    \"\"\"Top-k gating with learned, input-dependent Gaussian noise added to the routing logits.\n",
    "\n",
    "    ``forward`` returns ``(router_output, indices)``: softmax weights over all experts that\n",
    "    are zero outside the selected ones, and the ids of the selected experts.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_embed, num_experts, top_k):\n",
    "        super().__init__()\n",
    "        self.top_k = top_k\n",
    "        # Produces the clean routing logits.\n",
    "        self.topkroute_linear = nn.Linear(n_embed, num_experts)\n",
    "        # Produces the (pre-softplus) noise scale of every expert.\n",
    "        self.noise_linear = nn.Linear(n_embed, num_experts)\n",
    "\n",
    "    def forward(self, mh_output):\n",
    "        clean_logits = self.topkroute_linear(mh_output)\n",
    "\n",
    "        # softplus keeps the per-token, per-expert noise scale positive.\n",
    "        noise_scale = F.softplus(self.noise_linear(mh_output))\n",
    "\n",
    "        # Perturb the logits with unit Gaussian noise multiplied by that scale.\n",
    "        noisy_logits = clean_logits + torch.randn_like(clean_logits) * noise_scale\n",
    "\n",
    "        # Keep the top_k noisy logits, put them back into a tensor of -inf and normalise:\n",
    "        # the experts that were not selected end up with weight exactly 0.\n",
    "        kept_logits, indices = torch.topk(noisy_logits, k=self.top_k, dim=-1)\n",
    "        masked_logits = torch.full_like(noisy_logits, float('-inf')).scatter(\n",
    "            dim=-1, index=indices, src=kept_logits\n",
    "        )\n",
    "        return F.softmax(masked_logits, dim=-1), indices"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "1ae6c4d0980dab48"
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 6. 将上述结合，最终构建稀疏MOE"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "d88622179f5a6939"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "class SparseMoE(nn.Module):\n",
    "    \"\"\"Sparse mixture-of-experts layer.\n",
    "\n",
    "    Every token is routed to its top_k experts by a NoisyTopkRouter. The output for a token\n",
    "    is the sum of the outputs of its selected experts, weighted by the router's gating scores.\n",
    "\n",
    "    Args:\n",
    "        n_embed: width of the token vectors (last dimension of the input).\n",
    "        num_experts: number of Expert modules.\n",
    "        top_k: number of experts used per token.\n",
    "\n",
    "    ``forward`` returns a tensor with the same shape as its input.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_embed, num_experts, top_k):\n",
    "        super().__init__()\n",
    "        self.router = NoisyTopkRouter(n_embed, num_experts, top_k)\n",
    "        self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])\n",
    "        self.top_k = top_k\n",
    "\n",
    "    def forward(self, x):\n",
    "        # 1. Route: gating weights over all experts and the ids of the selected experts.\n",
    "        gating_output, indices = self.router(x)\n",
    "        # 2. Accumulator for the weighted expert outputs, same shape as the input.\n",
    "        final_output = torch.zeros_like(x)\n",
    "\n",
    "        # 3. Flatten the leading (batch, token) dimensions so tokens can be selected with one\n",
    "        #    boolean mask. reshape, unlike view, also accepts non-contiguous inputs.\n",
    "        flat_x = x.reshape(-1, x.size(-1))\n",
    "        flat_gating_output = gating_output.reshape(-1, gating_output.size(-1))\n",
    "\n",
    "        # Work expert by expert: each expert processes all the tokens routed to it at once.\n",
    "        for i, expert in enumerate(self.experts):\n",
    "            # 4. Tokens that have expert i among their top_k choices.\n",
    "            expert_mask = (indices == i).any(dim=-1)\n",
    "            # 5. The same mask over the flattened tokens.\n",
    "            flat_mask = expert_mask.reshape(-1)\n",
    "            # Skip experts that no token selected.\n",
    "            if flat_mask.any():\n",
    "                # 6. Gather the vectors of the tokens handled by this expert.\n",
    "                expert_input = flat_x[flat_mask]\n",
    "                # 7. Run them through the expert.\n",
    "                expert_output = expert(expert_input)\n",
    "\n",
    "                # 8. Gating weight of this expert for each of those tokens, as a column.\n",
    "                gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1)\n",
    "                # 9. Scale the expert output by its gating weight.\n",
    "                weighted_output = expert_output * gating_scores\n",
    "\n",
    "                # 10. Accumulate into the final result. weighted_output is already\n",
    "                #     (num_selected, n_embed), so it is added as is; squeezing dim 1 would\n",
    "                #     drop the feature axis when n_embed == 1.\n",
    "                final_output[expert_mask] += weighted_output\n",
    "\n",
    "        return final_output\n"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "ab0e6d83fd0a9364"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
